
ggml-cuda: add mem check for fusion #19916

Merged
am17an merged 3 commits into ggml-org:master from am17an:cuda_add_memcheck
Mar 6, 2026

Conversation

@am17an (Contributor) commented Feb 26, 2026

Fixes #19659

@github-actions bot added the labels "Nvidia GPU" (Issues specific to Nvidia GPUs) and "ggml" (changes relating to the ggml tensor library for machine learning) Feb 26, 2026
@am17an am17an marked this pull request as ready for review February 28, 2026 05:47
Comment on lines +122 to +132
// Sanitize NaN to -FLT_MAX so the iterative argmax produces unique expert IDs.
// NaN comparisons always return false, which would cause the same expert to be
// selected repeatedly. -FLT_MAX compares normally and is still excluded by the
// -INFINITY sentinel used after each selection round.
// More relevant for the cuBLAS path. See https://github.com/ggml-org/llama.cpp/issues/19659
#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        if (__isnanf(wt[i])) {
            wt[i] = -FLT_MAX;
        }
    }
@ORippler (Collaborator) commented Mar 2, 2026

  1. If the issue is in llama.cpp and not cuBLAS, I feel we should use fmaxf as a NaN-safe comparator: https://docs.nvidia.com/cuda/cuda-math-api/cuda_math_api/group__CUDA__MATH__SINGLE.html#_CPPv45fmaxfff (I presume we are talking about val_s > max_val_s later on in this kernel?)
  2. If the issue is in cuBLAS, I'd love more details so I can ask the cuBLAS team/take a look myself

@am17an (Contributor, Author) replied:

  1. Yes, but it's not just val_s > max_val_s, it's val_s > max_val_s || (val_s == max_val_s && expert < max_expert)
  2. The linked issue has a repro. It's cuBLAS + Nemotron, so think it would be fun for you guys to look at :)

@ORippler (Collaborator) commented Mar 2, 2026

> Yes, but it's not just val_s > max_val_s, it's val_s > max_val_s || (val_s == max_val_s && expert < max_expert)

Shouldn't we be fine with fmaxf, so long as max_val & max_val_s are initialized to -FLT_MAX instead of -INFINITY at the beginning of the selection-loop over n_expert_used? At least for the case where k non-NAN values exist inside the logits for a given row. But at this point we are just pulling your proposal into the loop itself 😄

@am17an (Contributor, Author) commented Mar 5, 2026

@JohannesGaessler can you review this PR? Apart from the NaN check, it also fixes a latent bug.

Comment on lines +3382 to +3386
if ((b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end)) {
return true;
}

return false;
A contributor commented:

Suggested change
if ((b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end)) {
return true;
}
return false;
return (b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end);

This would maybe be slightly simpler but either way is fine I think.

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
@am17an am17an merged commit d48e876 into ggml-org:master Mar 6, 2026
73 of 75 checks passed
@am17an am17an deleted the cuda_add_memcheck branch March 6, 2026 16:05
bartowski1182 pushed a commit to bartowski1182/llama.cpp that referenced this pull request Mar 10, 2026
* ggml-cuda: add mem check for fusion

* Replace NaNs with -FLT_MAX

* fix typo

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Ethan-a2 pushed a commit to Ethan-a2/llama.cpp that referenced this pull request Mar 20, 2026

Labels

ggml (changes relating to the ggml tensor library for machine learning), Nvidia GPU (Issues specific to Nvidia GPUs)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Eval bug: [CUDA, cuBLAS] Corrupted output on CUBLAS with moe models like Nemotron-3-nano and gpt-oss-120b with long context preprocessing

3 participants